"""
Optical/metric analogue for the kernel‑to‑metric simulation.

We use the weak‑gradient eikonal approximation:
  n(x, y) = 1 + λ · Ĝ(x, y)
and estimate the deflection for a ray at impact parameter b by
integrating along x:
  α(b) ≈ -∫ (∂n/∂y) / n^2  dx
Rays above/below the centre at ±b are averaged in magnitude to reduce noise.
"""

from __future__ import annotations

import numpy as np
from typing import Tuple

from .analysis import lensing_fit


def _compute_row_deflection(n: np.ndarray, dn_dy: np.ndarray, row: int) -> float:
    """Deflection for a single horizontal ray at a given row.

    Integrates -(∂n/∂y) / n^2 along the row (axis 1) with the trapezoidal
    rule, assuming unit pixel spacing.

    Inputs
    ------
    n     : 2-D index field, indexed [row, column].
    dn_dy : derivative of n along axis 0, same shape as n.
    row   : index of the row (ray) to integrate.

    Returns
    -------
    The signed integrated deflection as a Python float.
    """
    integrand = -dn_dy[row, :] / (n[row, :] ** 2)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
    # new name when present and fall back for older NumPy releases.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    # Unit pixel spacing → dx = 1.0
    return float(trapezoid(integrand, dx=1.0))


def compute_deflection_curve(G_hat: np.ndarray,
                             lambd: float,
                             ell: int,
                             L: int,
                             b_min_factor: float = 4.0,
                             b_max_factor: float = 0.3,
                             num_rays: int = 128) -> Tuple[np.ndarray, np.ndarray, float, float, Tuple[float, float]]:
    """
    Compute deflection angles for a range of impact parameters.

    Inputs
    ------
    G_hat : (L,L) normalised gradient magnitude (already divided by its mean).
    lambd : coupling coefficient λ.
    ell   : smoothing width (sets the inner b cut).
    L     : lattice size.
    b_min_factor : inner cut, b_min = max(ceil(b_min_factor * ell), 1).
    b_max_factor : outer cut, b_max = floor(b_max_factor * L), additionally
                   capped at (L - 1) // 2 so both rays c ± b lie inside the
                   lattice.
    num_rays : number of impact parameters requested; duplicates produced by
               integer truncation are removed, so fewer may be returned.

    Returns
    -------
    b_vals, alpha_vals, slope, R2, (b_min, b_max)

    Raises
    ------
    ValueError
        If G_hat is not 2-D with L rows, or if the lattice cannot hold at
        least two distinct impact parameters starting at b_min.
    """
    # Build index field from NORMALISED source (see io_fphs.kernel_to_envelope_2d)
    n = 1.0 + float(lambd) * np.asarray(G_hat, dtype=np.float64)
    # The ray rows are placed around c = (L - 1) / 2, so the row axis must
    # really have L entries for the geometry to be meaningful.
    if n.ndim != 2 or n.shape[0] != L:
        raise ValueError(
            f"G_hat must be a 2-D array with L={L} rows; got shape {n.shape}"
        )
    # Vertical derivative of n (np.gradient returns axis 0 = rows first)
    dn_dy, _dn_dx = np.gradient(n)

    # Impact parameter range (in pixel rows around the centre)
    b_min = max(int(np.ceil(b_min_factor * ell)), 1)
    b_max = int(np.floor(b_max_factor * L))
    if b_max <= b_min:
        b_max = max(b_min + 1, b_min * 2)

    # Largest offset for which both round(c + b) and round(c - b) fall inside
    # [0, L-1], for odd and even L alike. Clamping larger offsets to the edge
    # row would reuse that row for many different b and corrupt the fit.
    b_limit = (L - 1) // 2
    if b_limit <= b_min:
        raise ValueError(
            f"lattice too small for the requested impact parameters: "
            f"b_min={b_min} but L={L} only supports offsets up to {b_limit}; "
            f"reduce b_min_factor or ell"
        )
    b_max = min(b_max, b_limit)

    # Sample rows symmetrically about the centre
    c = (L - 1) / 2.0
    b_indices = np.unique(np.linspace(b_min, b_max, num_rays, dtype=int))
    b_vals = b_indices.astype(float)
    alphas = np.empty_like(b_vals, dtype=float)

    for i, b in enumerate(b_indices):
        # round() rounds halves to even: for even L (c is a half-integer) the
        # two rows sit at offsets b - 0.5 and b + 0.5, whose mean is b.
        row_up = int(round(c + b))
        row_dn = int(round(c - b))
        a_up = _compute_row_deflection(n, dn_dy, row_up)
        a_dn = _compute_row_deflection(n, dn_dy, row_dn)
        # Average magnitudes to avoid sign cancellation under symmetry
        alphas[i] = 0.5 * (abs(a_up) + abs(a_dn))

    slope, R2 = lensing_fit(b_vals, alphas)
    return b_vals, alphas, float(slope), float(R2), (float(b_min), float(b_max))
